from __future__ import division, print_function, absolute_import

import os
import pickle

import numpy as np
import torch as t
from CDTools.tools.cmath import *
from RPI_tools import *

# The plan: take a fairly typical (run-of-the-mill) simulation setup and
# run band-limited reconstructions of it at a whole range of resolutions.


# Band limit of the simulated probe (passed to generate_blr_probe as its
# maximum k).
probe_maxk = 128
# Object size relative to the probe band limit; one study is run per ratio.
resolution_ratios = [0.3, 0.35, 0.4]

# How many reconstructions are attempted at each resolution point.
attempts = 1000
# Controls the spacing of the padding sweep (step = probe_maxk // npoints).
npoints = 21

# Passed through to reconstruct(); set to False on machines without a GPU.
GPU = True

# Side length of the square reciprocal-space array the probe is simulated on,
# and the field-of-view radius used for it (40% of that side, rounded down).
probe_stage = 2 * int(probe_maxk)
probe_fov_radius = probe_stage * 4 // 10

# Results are nested per resolution ratio, then per resolution, then per
# attempt: losses[ratio][resolution][attempt] (errors has the same layout).
losses = []
errors = []
objs = []  # the random test object used for each ratio
all_resolutions = []  # the array sizes swept for each ratio

# The probe's parameters (probe_stage, probe_fov_radius, probe_maxk) do not
# depend on the resolution ratio, so it is generated once and shared by all
# ratios. This keeps the ratios directly comparable and guarantees that the
# single probe stored with the results is the one every reconstruction used.
nearfield = generate_blr_probe([probe_stage, probe_stage], probe_fov_radius, probe_maxk)

for resolution_ratio in resolution_ratios:

    # Side length of the random complex test object for this ratio.
    obj_size = int(probe_maxk * resolution_ratio) * 2
    # Border (a quarter of the object size) excluded from the error metric.
    ec_padding = obj_size // 4
    # Slice end for that border; `-0` would select nothing, so fall back to
    # None (i.e. "to the end") when there is no border.
    ec_end = -ec_padding or None
    error_window = np.s_[ec_padding:ec_end, ec_padding:ec_end]

    real = np.random.randn(obj_size, obj_size)
    imag = np.random.randn(obj_size, obj_size)
    original_image = complex_to_torch(real + 1j * imag).to(dtype=t.float32)
    objs.append(original_image)
    # The ground truth is the same for every attempt, so convert it once.
    im = torch_to_complex(original_image)

    # Sweep the array size by adding extra padding on each side.
    # NOTE(review): np.arange(0, 128, 128 // 21) yields 22 values, not
    # npoints (21); confirm which count is intended.
    extra_ks = np.arange(0, probe_maxk, probe_maxk // npoints)
    resolutions = obj_size + 2 * extra_ks
    all_resolutions.append(resolutions)

    res_losses = []
    res_errors = []
    for resolution in resolutions:

        # Only used by expand_probe (presumably for its shape).
        newshape = np.zeros([resolution, resolution])

        # The expanded probe will have enough zeros to avoid aliasing when
        # used with an object of size newshape.
        expanded_probe = expand_probe(nearfield, newshape)

        exit_wave = interact(original_image, expanded_probe)
        diffraction = propagate(exit_wave)
        intensities = measure(diffraction, expanded_probe.shape)

        # Scale the pattern to a total of 10 * probe_maxk**2 counts, i.e.
        # 10 photons per pixel of a probe_maxk x probe_maxk area. No Poisson
        # noise is applied.
        intensities /= t.sum(intensities)
        intensities *= 10 * probe_maxk**2

        loop_losses = []
        loop_errors = []
        for _ in range(attempts):
            # Each retrieval uses a different randomized initialization.
            # NOTE(review): the positional 0.4 and 1000 arguments are
            # interpreted by RPI_tools.reconstruct; confirm their meaning there.
            retrieval, loss = reconstruct(intensities, expanded_probe, resolution, 0.4, 1000, GPU=GPU, schedule=True)
            ret = torch_to_complex(retrieval)

            # RMS error against the ground truth; the slice presumably
            # restricts the comparison to the interior, away from the border.
            error, _, _ = compare_images(ret, im, error_window)

            loop_losses.append(loss[-1])
            loop_errors.append(error)
            print('Ratio',resolution_ratio,' and resolution', resolution, 'had loss',loss[-1], 'and error',  error)

        res_losses.append(loop_losses)
        res_errors.append(loop_errors)

    losses.append(res_losses)
    errors.append(res_errors)

# Bundle everything needed to analyse the run afterwards.
results = {'probe': nearfield, 'losses':losses, 'resolutions': all_resolutions, 'errors':errors, 'obj': objs}

output_path = 'data/bandlimiting usefulness trial 1000.pickle'
# Make sure the output directory exists; otherwise a long run would end in a
# FileNotFoundError and its results would be lost. (Written without
# exist_ok to stay compatible with the file's Python 2 __future__ imports.)
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)

with open(output_path, 'wb') as f:
    pickle.dump(results, f)
